{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import logging\n",
    "import numpy as np\n",
    "import random\n",
    "from tqdm.autonotebook import tqdm\n",
    "\n",
    "import mindspore\n",
    "from mindspore import nn\n",
    "from mindspore import Parameter, Tensor\n",
    "from mindspore.nn import AdamWeightDecay as AdamW\n",
    "import mindspore.dataset as ds\n",
    "from mindspore.dataset import GeneratorDataset, transforms\n",
    "from mindspore.dataset.text import Vocab as msVocab\n",
    "\n",
    "from mindnlp.modules import CRF\n",
    "from mindnlp.models import BertConfig, BertModel\n",
    "from mindnlp.transforms import BertTokenizer\n",
    "from mindnlp.engine import Trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def seed_everything(seed):\n",
    "    \"\"\"Seed every random number source this notebook relies on.\n",
    "\n",
    "    Seeds Python's ``random`` module, exports ``PYTHONHASHSEED``, and then\n",
    "    seeds NumPy, MindSpore and the MindSpore dataset pipeline with ``seed``.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "    # The remaining seeders all share the same one-argument call shape.\n",
    "    for apply_seed in (np.random.seed, mindspore.set_seed, mindspore.dataset.config.set_seed):\n",
    "        apply_seed(seed)\n",
    "\n",
    "def read_data(path):\n",
    "    \"\"\"Read a CoNLL-style file and return ``(sentences, labels)``.\n",
    "\n",
    "    Every non-blank line is split on whitespace; its first column is taken as\n",
    "    the token and its last column as the tag. Blank lines separate sentences.\n",
    "\n",
    "    Args:\n",
    "        path: Path of a UTF-8 encoded text file.\n",
    "\n",
    "    Returns:\n",
    "        A tuple of two parallel lists: ``sentences[i]`` is the list of tokens\n",
    "        of sentence ``i`` and ``labels[i]`` is the list of its tag strings.\n",
    "    \"\"\"\n",
    "    sentences = []\n",
    "    labels = []\n",
    "    sent = []\n",
    "    label = []\n",
    "    with open(path, 'r', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            parts = line.split()\n",
    "            if parts:\n",
    "                sent.append(parts[0])\n",
    "                label.append(parts[-1])\n",
    "            elif sent:\n",
    "                # A blank line closes the sentence collected so far.\n",
    "                sentences.append(sent)\n",
    "                labels.append(label)\n",
    "                sent, label = [], []\n",
    "    # Flush the trailing sentence: a file that does not end with a blank line\n",
    "    # would otherwise silently lose its last sentence.\n",
    "    if sent:\n",
    "        sentences.append(sent)\n",
    "        labels.append(label)\n",
    "\n",
    "    return (sentences, labels)\n",
    "\n",
    "def read_vocab(path):\n",
    "    \"\"\"Return the lines of the UTF-8 file at ``path`` as a list of stripped strings.\"\"\"\n",
    "    with open(path, 'r', encoding='utf-8') as handle:\n",
    "        return [entry.strip() for entry in handle]\n",
    "\n",
    "def get_entity(decode):\n",
    "    \"\"\"Group a sequence of tag ids into ``(positions, entity_type)`` spans.\n",
    "\n",
    "    An odd id opens a new span whose type is ``labels_text_mp[id // 2]``; an\n",
    "    even positive id extends the most recent span while one is open; any id\n",
    "    that is not positive closes the current span.\n",
    "    \"\"\"\n",
    "    spans = []\n",
    "    inside = False\n",
    "    for position, tag in enumerate(decode):\n",
    "        if not tag > 0:\n",
    "            inside = False\n",
    "            continue\n",
    "        if tag % 2 == 1:\n",
    "            inside = True\n",
    "            spans.append(([position], labels_text_mp[tag // 2]))\n",
    "        elif inside:\n",
    "            spans[-1][0].append(position)\n",
    "    return spans\n",
    "\n",
    "# Data processing.\n",
    "# Feature holds one pre-processed sentence: padded token ids, padded tag ids,\n",
    "# the sequence length, and the entity spans derived from the padded tag ids.\n",
    "# It reads the module-level names LABEL_MAP, tokenizer and max_Len.\n",
    "class Feature(object):\n",
    "    # sent: list of word strings; label: list of tag strings, one per word.\n",
    "    def __init__(self,sent, label):\n",
    "        self.sent = sent\n",
    "        # Map tag strings to their integer ids.\n",
    "        label = [LABEL_MAP[c] for c in label]\n",
    "        self.token_ids = list(tokenizer(' '.join(sent)))\n",
    "        # Number of ids, capped at max_Len + 2.\n",
    "        # NOTE(review): the +2 and the [1:-1] below presumably account for one\n",
    "        # leading and one trailing special token in the tokenizer output - confirm.\n",
    "        self.seq_length = len(self.token_ids) if len(self.token_ids) - 2 < max_Len else max_Len + 2\n",
    "        offset = tokenizer.encode(' '.join(sent)).offsets\n",
    "        # Expand the per-word tags to one tag per token, then truncate to max_Len,\n",
    "        # surround with a 0 on each side and right-pad with 0 to max_Len + 2.\n",
    "        self.labels = self.get_labels(offset, label)\n",
    "        self.labels = [0] + self.labels[:max_Len] + [0]\n",
    "        self.labels = self.labels + [0]*(max_Len - len(self.labels) + 2)\n",
    "        \n",
    "        # Same truncate-and-pad scheme for the token ids.\n",
    "        # NOTE(review): 101 / 102 are presumably the [CLS] / [SEP] ids of the\n",
    "        # bert-base-uncased vocabulary and 0 the padding id - confirm.\n",
    "        self.token_ids = [101] + self.token_ids[1:-1][:max_Len] + [102]\n",
    "        self.token_ids = self.token_ids + [0]*(max_Len - len(self.token_ids) + 2)\n",
    "        # Gold entity spans, computed on the padded tag sequence.\n",
    "        self.entity = get_entity(self.labels)\n",
    "\n",
    "    # Return one tag id per token whose offset pair is not (0, 0).\n",
    "    # offset_mapping: iterable of (start, end) pairs; label: per-word tag ids.\n",
    "    def get_labels(self, offset_mapping, label):\n",
    "        # sent_len: total characters of the words consumed so far\n",
    "        # count:    total characters covered by the tokens seen so far\n",
    "        # index:    number of words consumed so far\n",
    "        sent_len, count, index = 0, 0, 0\n",
    "        label_new = []\n",
    "        for l, r in offset_mapping:\n",
    "            # (0, 0) pairs are skipped - presumably special tokens; verify.\n",
    "            if l != 0 or r != 0:\n",
    "                # All characters of the consumed words are covered, so this\n",
    "                # token starts the next word.\n",
    "                # NOTE(review): assumes token spans add up exactly to each\n",
    "                # word's length; otherwise the alignment drifts - confirm.\n",
    "                if count == sent_len:\n",
    "                    sent_len += len(self.sent[index])\n",
    "                    index += 1\n",
    "                count += r - l\n",
    "                # Every token receives the tag of the word it belongs to.\n",
    "                label_new.append(label[index-1])\n",
    "                \n",
    "        return label_new\n",
    "\n",
    "class GetDatasetGenerator:\n",
    "    \"\"\"Random-access source of ``Feature`` objects built from a data file.\"\"\"\n",
    "\n",
    "    def __init__(self, path):\n",
    "        \"\"\"Read ``path`` with ``read_data`` and build one ``Feature`` per sentence.\"\"\"\n",
    "        sentences, tag_lists = read_data(path)\n",
    "        built = []\n",
    "        for position, sentence in enumerate(sentences):\n",
    "            built.append(Feature(sentence, tag_lists[position]))\n",
    "        self.features = built\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of stored features.\"\"\"\n",
    "        return len(self.features)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(token_ids, seq_length, labels)`` of the feature at ``index``.\"\"\"\n",
    "        item = self.features[index]\n",
    "        return item.token_ids, item.seq_length, item.labels\n",
    "\n",
    "def process_dataset(source, batch_size, shuffle):\n",
    "    \"\"\"Wrap ``source`` in a GeneratorDataset and batch it.\n",
    "\n",
    "    The dataset exposes the columns ``ids``, ``seq_length`` and ``labels``;\n",
    "    ``shuffle`` is forwarded to the GeneratorDataset constructor.\n",
    "    \"\"\"\n",
    "    column_names = [\"ids\", \"seq_length\", \"labels\"]\n",
    "    return ds.GeneratorDataset(source, column_names, shuffle=shuffle).batch(batch_size=batch_size)\n",
    "\n",
    "def debug_dataset(dataset):\n",
    "    \"\"\"Print the shapes of the ``ids`` and ``labels`` columns of the first batch.\n",
    "\n",
    "    ``dataset`` is batched here with a batch size of 16, so pass a dataset that\n",
    "    has not been batched yet. It must expose the ``ids`` and ``labels`` columns\n",
    "    that ``process_dataset`` declares.\n",
    "    \"\"\"\n",
    "    dataset = dataset.batch(batch_size=16)\n",
    "    for data in dataset.create_dict_iterator():\n",
    "        # The columns are named ids / labels; the former data / label keys do\n",
    "        # not exist in these datasets and raised KeyError.\n",
    "        print(data[\"ids\"].shape, data[\"labels\"].shape)\n",
    "        break\n",
    "        \n",
    "def get_metric(P_ans, valid):\n",
    "    \"\"\"Print entity-level F1, precision and recall.\n",
    "\n",
    "    ``P_ans[i]`` holds the predicted entities of example ``i`` and\n",
    "    ``valid.features[i].entity`` the reference entities. A reference entity\n",
    "    counts as correct when it also appears in the prediction list.\n",
    "    \"\"\"\n",
    "    hits = 0       # reference entities that were also predicted\n",
    "    n_predicted = 0  # number of predicted entities\n",
    "    n_reference = 0  # number of reference entities\n",
    "    for idx, predicted in enumerate(P_ans):\n",
    "        n_predicted += len(predicted)\n",
    "        reference = valid.features[idx].entity\n",
    "        n_reference += len(reference)\n",
    "        hits += len([ent for ent in reference if ent in predicted])\n",
    "    # Each ratio falls back to 0. when its denominator is empty.\n",
    "    P = hits/n_predicted if n_predicted>0 else 0.\n",
    "    R = hits/n_reference if n_reference>0 else 0.\n",
    "    if (P+R)>0:\n",
    "        f1 = (2*P*R)/(P+R)\n",
    "    else:\n",
    "        f1 = 0.\n",
    "    print(f'f1 = {f1}， P(准确率) = {P}, R(召回率) = {R}')\n",
    "    \n",
    "def get_optimizer(model):\n",
    "    \"\"\"Build the AdamWeightDecay optimizer with three parameter groups.\n",
    "\n",
    "    Groups, in order:\n",
    "      1. parameters that are neither CRF-named nor in ``no_decay``: weight decay 0.8\n",
    "      2. bias / LayerNorm parameters that are not CRF-named: weight decay 0.0\n",
    "      3. parameters whose name contains 'crf' (this covers both ``crf`` and\n",
    "         ``crf_hidden_fc``): learning rate 3e-3, weight decay 0.8\n",
    "\n",
    "    Args:\n",
    "        model: A cell exposing ``parameters_and_names()``.\n",
    "\n",
    "    Returns:\n",
    "        The configured optimizer (base learning rate 3e-5, eps 1e-8).\n",
    "    \"\"\"\n",
    "    named_params = list(model.parameters_and_names())\n",
    "\n",
    "    # LayerNorm parameters are named gamma / beta in this framework (see the\n",
    "    # layer_norm.gamma / layer_norm.beta names in the checkpoint-loading log),\n",
    "    # so the weight / bias spellings alone never excluded them from decay.\n",
    "    no_decay = ['bias', 'layer_norm.bias', 'layer_norm.weight',\n",
    "                'layer_norm.gamma', 'layer_norm.beta']\n",
    "    # A set keeps the membership test below constant-time.\n",
    "    crf_names = {n for n, _ in named_params if 'crf' in str(n)}\n",
    "\n",
    "    decay_params, no_decay_params, crf_params = [], [], []\n",
    "    for n, p in named_params:\n",
    "        if n in crf_names:\n",
    "            crf_params.append(p)\n",
    "        elif any(nd in n for nd in no_decay):\n",
    "            no_decay_params.append(p)\n",
    "        else:\n",
    "            decay_params.append(p)\n",
    "\n",
    "    optimizer_grouped_parameters = [\n",
    "        {'params': decay_params, 'weight_decay': 0.8},\n",
    "        {'params': no_decay_params, 'weight_decay': 0.0},\n",
    "        {'params': crf_params, 'lr': 3e-3, 'weight_decay': 0.8},\n",
    "    ]\n",
    "    # Keep the base learning rate small; otherwise the predictions may all be 0.\n",
    "    optimizer = AdamW(optimizer_grouped_parameters, learning_rate=3e-5, eps=1e-8)\n",
    "\n",
    "    return optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Bert_LSTM_CRF(nn.Cell):\n",
    "    \"\"\"BERT encoder followed by a bidirectional LSTM, a dense layer and a CRF.\"\"\"\n",
    "\n",
    "    def __init__(self, num_labels):\n",
    "        \"\"\"Create the sub-layers; ``num_labels`` is the size of the tag set.\"\"\"\n",
    "        super().__init__()\n",
    "        self.num_labels = num_labels\n",
    "        bert_config = BertConfig.from_pretrained('bert-base-uncased')\n",
    "        hidden = bert_config.hidden_size\n",
    "        self.bert_model = BertModel.from_pretrained('bert-base-uncased', config=bert_config)\n",
    "        # Two directions of hidden // 2 units each feed the dense layer below.\n",
    "        self.bilstm = nn.LSTM(hidden, hidden // 2, batch_first=True, bidirectional=True)\n",
    "        self.crf_hidden_fc = nn.Dense(hidden, self.num_labels)\n",
    "        self.crf = CRF(self.num_labels, batch_first=True, reduction='mean')\n",
    "\n",
    "    def construct(self, ids, seq_length=None, labels=None):\n",
    "        \"\"\"Run the full stack and return the CRF layer's output.\n",
    "\n",
    "        Positions whose id is greater than 0 form the attention mask. This\n",
    "        notebook treats the result as the training loss when ``labels`` is\n",
    "        given and unpacks it as ``(score, history)`` when it is omitted.\n",
    "        \"\"\"\n",
    "        bert_outputs = self.bert_model(input_ids=ids, attention_mask=(ids > 0))\n",
    "        sequence_feat, _ = self.bilstm(bert_outputs[0])\n",
    "        tag_scores = self.crf_hidden_fc(sequence_feat)\n",
    "        return self.crf(tag_scores, tags=labels, seq_length=seq_length)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "seed_everything(42)\n",
    "max_Len = 113\n",
    "Entity = ['PER', 'LOC', 'ORG', 'MISC']\n",
    "# index -> entity type name\n",
    "labels_text_mp = dict(enumerate(Entity))\n",
    "# Tag string -> id: 'O' is 0, and the n-th entity type (counting from 1)\n",
    "# gets 2n - 1 for its B- tag and 2n for its I- tag.\n",
    "LABEL_MAP = {'O': 0}\n",
    "for rank, name in enumerate(Entity, start=1):\n",
    "    LABEL_MAP[f'B-{name}'] = 2 * rank - 1\n",
    "    LABEL_MAP[f'I-{name}'] = 2 * rank"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "# The splits are built in this order: train, test, dev.\n",
    "train, test, dev = [\n",
    "    GetDatasetGenerator(split_path)\n",
    "    for split_path in ('conll2003/train.txt', 'conll2003/test.txt', 'conll2003/valid.txt')\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(60169:140459444134528,MainProcess):2023-06-07-14:40:53.875.833 [mindspore/lib/python3.7/site-packages/mindnlp/abc/models/pretrained_model.py:458] The following parameters in checkpoint files are not loaded:\n",
      "['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.layer_norm.gamma', 'cls.predictions.transform.layer_norm.beta', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n"
     ]
    }
   ],
   "source": [
    "epochs, batch_size = 3, 16\n",
    "dataset_train = process_dataset(train, batch_size=batch_size, shuffle=False)\n",
    "# One B- and one I- tag per entity type, plus the single 'O' tag.\n",
    "model = Bert_LSTM_CRF(num_labels=2 * len(Entity) + 1)\n",
    "optimizer = get_optimizer(model)\n",
    "trainer = Trainer(\n",
    "    network=model,\n",
    "    train_dataset=dataset_train,\n",
    "    optimizer=optimizer,\n",
    "    epochs=epochs,\n",
    "    jit=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|███████████████████████████| 937/937 [07:40<00:00,  2.04it/s, loss=3.2058835]\n",
      "Epoch 1: 100%|██████████████████████████| 937/937 [07:16<00:00,  2.15it/s, loss=0.96710545]\n",
      "Epoch 2: 100%|██████████████████████████| 937/937 [07:17<00:00,  2.14it/s, loss=0.65512425]\n"
     ]
    }
   ],
   "source": [
    "# Start training with the Trainer configured in the previous cell.\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████| 937/937 [08:14<00:00,  1.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1 = 0.9731468570087394， P(准确率) = 0.970961787475549, R(召回率) = 0.9753417833619306\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Predict: train\n",
    "model.set_train(False)\n",
    "dataset_train = process_dataset(train, batch_size=batch_size, shuffle=False)\n",
    "steps = size = dataset_train.get_dataset_size()\n",
    "decodes = []\n",
    "for token_ids, seq_length, labels in tqdm(dataset_train.create_tuple_iterator(), total=steps):\n",
    "    # Called without labels, the model output is unpacked as (score, history).\n",
    "    score, history = model(token_ids, seq_length=seq_length)\n",
    "    best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "    # Convert every decoded tag to a plain Python number.\n",
    "    decodes.extend([tag.asnumpy().item() for tag in tags] for tags in best_tags)\n",
    "\n",
    "v_pred = [get_entity(tags) for tags in decodes]\n",
    "get_metric(v_pred, train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████| 217/217 [01:56<00:00,  1.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1 = 0.9254796965640338， P(准确率) = 0.9212659633536924, R(召回率) = 0.9297321528633867\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Predict: dev\n",
    "model.set_train(False)\n",
    "dataset_dev = process_dataset(dev, batch_size=batch_size, shuffle=False)\n",
    "steps = size = dataset_dev.get_dataset_size()\n",
    "decodes = []\n",
    "for token_ids, seq_length, labels in tqdm(dataset_dev.create_tuple_iterator(), total=steps):\n",
    "    # Called without labels, the model output is unpacked as (score, history).\n",
    "    score, history = model(token_ids, seq_length=seq_length)\n",
    "    best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "    # Convert every decoded tag to a plain Python number.\n",
    "    decodes.extend([tag.asnumpy().item() for tag in tags] for tags in best_tags)\n",
    "\n",
    "v_pred = [get_entity(tags) for tags in decodes]\n",
    "get_metric(v_pred, dev)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████| 231/231 [02:04<00:00,  1.86it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f1 = 0.8855892096545197， P(准确率) = 0.8822489391796322, R(召回率) = 0.8889548693586699\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Predict: test\n",
    "model.set_train(False)\n",
    "dataset_test = process_dataset(test, batch_size=batch_size, shuffle=False)\n",
    "steps = size = dataset_test.get_dataset_size()\n",
    "decodes_pred = []\n",
    "for token_ids, seq_length, labels in tqdm(dataset_test.create_tuple_iterator(), total=steps):\n",
    "    # Called without labels, the model output is unpacked as (score, history).\n",
    "    score, history = model(token_ids, seq_length=seq_length)\n",
    "    best_tags = model.crf.post_decode(score, history, seq_length)\n",
    "    # Convert every decoded tag to a plain Python number.\n",
    "    decodes_pred.extend([tag.asnumpy().item() for tag in tags] for tags in best_tags)\n",
    "\n",
    "pred = [get_entity(tags) for tags in decodes_pred]\n",
    "get_metric(pred, test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "vscode": {
   "interpreter": {
    "hash": "a62cb8bb4abcff3256df5ab1881dc7c3e7803473070698df3ff917df10adcce5"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
